import torch


def create_batch_info(data, edge_counter):
    """
    Compute the batch information dictionary for a PyG Data/Batch object.

    Missing attributes are filled in on ``data`` (it is mutated in place):

    * ``num_nodes`` defaults to ``data.x.size(0)``;
    * ``batch`` defaults to all zeros (a single graph), with ``num_graphs = 1``;
    * otherwise ``num_graphs`` defaults to ``batch.max() + 1``;
    * ``coloring`` defaults to each node's position (in index order) within
      its own graph, stored as an (N, 1) long tensor.

    An attribute counts as missing when it is absent *or* ``None``: a plain
    PyG ``Data`` exposes ``batch`` as a property returning ``None``, so a bare
    ``hasattr`` check would wrongly treat it as present.

    Args:
        data: graph object exposing at least ``x`` (node features, indexed as
            N x F) and ``edge_index``.
        edge_counter: callable ``(ones, edge_index, batch, num_graphs)``; its
            result is indexed as ``[:, :, None]`` so it must be 2-D.

    Returns:
        dict with keys 'num_nodes', 'num_graphs', 'batch', 'n_per_graph',
        'n_batch' (N x 1 x 1 float), 'average_edges', 'coloring', 'n_colors'
        (0-d tensor, max color + 1) and 'mask' (N x n_colors bool, True where
        the color index is smaller than the size of the node's graph).
    """
    device = data.x.device

    if getattr(data, 'num_nodes', None) is None:
        data.num_nodes = data.x.size(0)
    num_nodes = data.num_nodes

    # No batch vector -> treat the input as a single graph.
    if getattr(data, 'batch', None) is None:
        data.batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
        data.num_graphs = 1
    elif getattr(data, 'num_graphs', None) is None:
        data.num_graphs = int(data.batch.max().item()) + 1

    batch = data.batch
    batch_size = data.num_graphs

    # Nodes per graph. `inverse` maps every node to its graph's slot in
    # `n_per_graph`, so per-node counts are a single gather instead of a loop.
    _, inverse, n_per_graph = torch.unique(batch, return_inverse=True, return_counts=True)
    nodes_in_own_graph = n_per_graph[inverse]
    n_batch = nodes_in_own_graph.float()

    # Average edge statistics, delegated to the caller-supplied counter.
    dummy = data.x.new_ones((num_nodes, 1))
    average_edges = edge_counter(dummy, data.edge_index, batch, batch_size)

    # Default coloring: node k of graph i (in index order) gets color k.
    if getattr(data, 'coloring', None) is None:
        coloring = data.x.new_zeros(num_nodes, dtype=torch.long)
        for graph_id in range(batch_size):
            idx = torch.nonzero(batch == graph_id, as_tuple=True)[0]
            coloring[idx] = torch.arange(idx.numel(), device=device)
        data.coloring = coloring[:, None]
    n_colors = torch.max(data.coloring) + 1  # indexing starts at 0

    # Color c is valid for a node iff c < number of nodes in that node's graph.
    color_ids = torch.arange(int(n_colors), device=device)
    mask = color_ids[None, :] < nodes_in_own_graph.to(device)[:, None]

    batch_info = {
        'num_nodes': num_nodes,
        'num_graphs': batch_size,
        'batch': batch,
        'n_per_graph': n_per_graph,
        'n_batch': n_batch[:, None, None].float(),
        'average_edges': average_edges[:, :, None],
        'coloring': data.coloring,
        'n_colors': n_colors,
        'mask': mask
    }
    return batch_info


def create_batch_info_dgl(g, edge_counter):
    """
    Compute the batch information dictionary for a single DGL graph.

    Produces the same keys as ``create_batch_info`` does for one graph.
    Every node belongs to graph 0, and node ``i`` gets color ``i``. Because
    the graph has ``num_nodes`` nodes, every color is valid for every node.
    NOTE(review): a batched DGL graph would be treated as one big graph here,
    since ``batch`` is always all zeros.

    Args:
        g: DGL graph whose node features are stored in ``g.ndata['feat']``.
        edge_counter: callable ``(ones, edge_index, batch, num_graphs)``; its
            result gets a trailing singleton dimension appended.

    Returns:
        dict with keys 'num_nodes', 'num_graphs', 'batch', 'n_per_graph',
        'n_batch' (N x 1 x 1 float), 'average_edges', 'coloring' (N x 1),
        'n_colors' (int) and 'mask' (N x N bool, all True).
    """
    x = g.ndata['feat']
    device = x.device
    num_nodes = x.size(0)

    # edge_index in PyG layout: row 0 = sources, row 1 = destinations.
    src, dst = g.edges()
    edge_index = torch.stack([src, dst], dim=0)

    # Single graph: every node is in graph 0.
    num_graphs = 1
    batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
    n_per_graph = torch.tensor([num_nodes], device=device)
    n_batch = torch.full((num_nodes, 1, 1), num_nodes, device=device, dtype=torch.float)

    average_edges = edge_counter(x.new_ones((num_nodes, 1)), edge_index, batch, num_graphs)

    # Unique coloring: node i gets color i.
    coloring = torch.arange(num_nodes, device=device).unsqueeze(1)
    n_colors = num_nodes

    # Same rule as the PyG path: color c is valid for a node iff c is smaller
    # than the size of the node's graph. With one graph of num_nodes nodes,
    # that is every color for every node. An identity matrix would mask out
    # every color but the node's own.
    mask = torch.ones(num_nodes, n_colors, dtype=torch.bool, device=device)

    batch_info = {
        'num_nodes': num_nodes,
        'num_graphs': num_graphs,
        'batch': batch,
        'n_per_graph': n_per_graph,
        'n_batch': n_batch,
        'average_edges': average_edges.unsqueeze(-1),
        'coloring': coloring,
        'n_colors': n_colors,
        'mask': mask
    }
    return batch_info


def map_x_to_u(data, batch_info):
    """Map the node features to the right row of the initial local context.

    For node ``i`` with color ``c = coloring[i]``, row ``c`` of its context
    holds a 1 in channel 0 (one-hot color indicator), followed by ``x[i]`` in
    channels ``1..F``. Every other row is zero.

    The coloring is read only from ``batch_info``, so both channels always
    agree and ``data`` needs nothing but ``x``.

    Args:
        data: object exposing the node feature matrix ``x`` (indexed as N x F).
        batch_info: dict as built by ``create_batch_info``; reads 'coloring'
            (N x 1 integer indices) and 'n_colors'.

    Returns:
        Tensor of shape (N, n_colors, 1 + F) with the dtype/device of ``x``.
    """
    x = data.x
    num_nodes, n_features = x.shape[0], x.shape[1]
    coloring = batch_info['coloring']       # N x 1
    n_colors = batch_info['n_colors']

    # Channel 0: one-hot encoding of each node's color.
    indicator = x.new_zeros((num_nodes, n_colors))
    indicator.scatter_(1, coloring, 1)
    indicator = indicator[..., None]

    # Channels 1..F: write x[i] into row coloring[i] of node i's context.
    features = x.new_zeros((num_nodes, n_colors, n_features))
    expanded_colors = coloring[..., None].expand(-1, -1, n_features)
    features.scatter_(1, expanded_colors, x[:, None, :])

    return torch.cat((indicator, features), dim=2)
